import warnings
import numpy as np
import os
import multiprocessing as mp
from functools import partial
from tslearn.metrics import dtw
import torch


def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue, *,
             dtype=None):
    """Pearson correlation matrix with non-finite entries replaced.

    Follows the same steps as ``numpy.corrcoef`` (covariance matrix divided
    by the outer product of the standard deviations), but passes the
    normalised matrix through ``np.nan_to_num`` before clipping, so a
    zero-variance variable yields 0 instead of NaN.

    Args:
        x, y, rowvar, dtype: forwarded unchanged to ``np.cov``.
        bias, ddof: ignored; supplying either one only triggers a
            ``DeprecationWarning``.

    Returns:
        The correlation matrix, with real (and, for complex input,
        imaginary) parts clipped to [-1, 1]. When the covariance is a
        scalar, ``cov / cov`` is returned instead (1 for a valid value,
        NaN for 0 / NaN / inf).
    """
    legacy_args_given = not (bias is np._NoValue and ddof is np._NoValue)
    if legacy_args_given:
        # Deprecated in NumPy 1.10 (2015-03-15).
        warnings.warn('bias and ddof have no effect and are deprecated',
                      DeprecationWarning, stacklevel=3)

    cov = np.cov(x, y, rowvar, dtype=dtype)

    try:
        variances = np.diag(cov)
    except ValueError:
        # Scalar covariance: np.diag rejects 0-d input.
        return cov / cov

    scale = np.sqrt(variances.real)

    # Normalise rows first, then columns, and only afterwards replace the
    # NaN/inf values produced by zero-variance variables.
    corr = np.nan_to_num(cov / scale[:, None] / scale[None, :])

    # Clipping the real and imaginary parts separately does not guarantee
    # abs(corr[i, j]) <= 1 for complex arrays, but it is cheap.
    np.clip(corr.real, -1, 1, out=corr.real)
    if np.iscomplexobj(corr):
        np.clip(corr.imag, -1, 1, out=corr.imag)
    return corr


def _ccc_single(series_true, series_pred):
    """Return the concordance correlation coefficient of two 1-D series."""
    # Use the module's NaN-safe corrcoef so that a constant series gives a
    # correlation of 0 rather than NaN.
    pearson = corrcoef(series_true, series_pred)[0][1]

    mean_true = np.mean(series_true)
    mean_pred = np.mean(series_pred)

    var_true = np.var(series_true)
    var_pred = np.var(series_pred)

    sd_true = np.std(series_true)
    sd_pred = np.std(series_pred)

    numerator = 2 * pearson * sd_true * sd_pred
    denominator = var_true + var_pred + (mean_true - mean_pred) ** 2

    # The small epsilon keeps the ratio finite when both series are constant
    # and equal (denominator == 0).
    return numerator / (denominator + 1e-8)


def concordance_correlation_coefficient(y_true, y_pred,
                                        sample_weight=None,
                                        multioutput='uniform_average'):
    """Concordance correlation coefficient (CCC) between two arrays.

    For inputs with more than one axis the CCC is computed independently for
    every index along axis 1 (``y_true[:, i]`` against ``y_pred[:, i]``) and
    the unweighted mean over those columns is returned. For 1-D inputs a
    single CCC is returned.

    Both branches use the same NaN-safe correlation, so a zero-variance
    series contributes a CCC of 0 instead of NaN.

    Args:
        y_true: reference values; the original note gives the shape as
            (seq_len, dim).
        y_pred: predicted values, indexed the same way as ``y_true``.
        sample_weight: accepted for signature compatibility; not used.
        multioutput: accepted for signature compatibility; not used.

    Returns:
        The CCC as a NumPy floating-point scalar.
    """
    if len(y_true.shape) > 1:
        per_column = [
            _ccc_single(y_true[:, i], y_pred[:, i])
            for i in range(y_true.shape[1])
        ]
        return np.mean(per_column)
    return _ccc_single(y_true, y_pred)


def FRC_func(k_neighbour_matrix, k_pred, em=None):
    """Sum, over predictions, of the best CCC against any neighbour.

    Args:
        k_neighbour_matrix: array whose entries equal to 1 mark the indices
            of ``em`` that count as neighbours.
        k_pred: array of predictions; iterated along its first axis.
        em: indexable collection of reference sequences. The default of
            ``None`` is not usable when there is at least one prediction
            and one neighbour, because ``em`` is indexed directly.

    Returns:
        The accumulated maximum CCC (``0`` when ``k_pred`` has no rows).

    Raises:
        ValueError: if there is at least one prediction but no entry of
            ``k_neighbour_matrix`` equals 1 (``max`` of an empty list).
    """
    # Flattened positions of every entry flagged with 1.
    neighbour_ids = np.argwhere(k_neighbour_matrix == 1).reshape(-1)

    total = 0
    for row in range(k_pred.shape[0]):
        scores = [
            concordance_correlation_coefficient(em[idx], k_pred[row])
            for idx in neighbour_ids
        ]
        total += max(scores)
    return total


def FRD_func(k_neighbour_matrix, k_pred, em=None):
    """Sum, over predictions, of the smallest weighted DTW to any neighbour.

    For each prediction the distance to a neighbour is the weighted sum of
    three DTW distances, each computed on a slice of the last axis:
    columns 0-14 (weight 1/15), 15-16 (weight 1) and 17-24 (weight 1/8).
    The minimum over all neighbours is added to the running total.
    NOTE(review): the column groups presumably correspond to distinct
    emotion descriptors -- confirm against the dataset definition.

    Args:
        k_neighbour_matrix: array whose entries equal to 1 mark the indices
            of ``em`` that count as neighbours.
        k_pred: array of predictions; iterated along its first axis, each
            item indexed as ``[:, start:end]``.
        em: indexable collection of reference sequences; must be supplied
            (it is indexed directly for every neighbour).

    Returns:
        The accumulated minimum distance (``0`` when ``k_pred`` has no
        rows).

    Raises:
        ValueError: if there is at least one prediction but no entry of
            ``k_neighbour_matrix`` equals 1 (``min`` of an empty list).
    """
    # (start, end, weight) for each column group of the last axis.
    segments = ((0, 15, 1 / 15), (15, 17, 1), (17, 25, 1 / 8))

    neighbour_index = np.argwhere(k_neighbour_matrix == 1).reshape(-1)

    # Cast every neighbour to float32 once, instead of once per segment and
    # per prediction.
    neighbours = [em[idx].astype(np.float32) for idx in neighbour_index]

    min_dtw_sum = 0
    for i in range(k_pred.shape[0]):
        # Cast the prediction once; it is reused for all neighbours/segments.
        pred = k_pred[i].astype(np.float32)
        distances = []
        for emotion in neighbours:
            res = 0
            for st, ed, weight in segments:
                res += weight * dtw(pred[:, st:ed], emotion[:, st:ed])
            distances.append(res)
        min_dtw_sum += min(distances)
    return min_dtw_sum


def compute_FRC_mp(dataset_path, pred, em, val_test='test', p=1):
    """Mean FRC over all samples, evaluated in a process pool.

    Loads the neighbour matrix for the requested split from
    ``dataset_path``, pairs each of its rows with the matching entry of
    ``pred`` and scores every pair with ``FRC_func`` in ``p`` worker
    processes.

    Args:
        dataset_path: directory containing the neighbour-matrix ``.npy``
            files.
        pred: tensor of predictions (original note: N x 10 x 750 x dim);
            converted with ``.numpy()``.
        em: tensor of reference emotions (original note: N x 750 x dim);
            converted with ``.numpy()`` and shared with every worker call.
        val_test: ``'val'`` selects the validation matrix; any other value
            selects the test matrix.
        p: number of worker processes.

    Returns:
        ``np.mean`` of the per-sample FRC values.
    """
    if val_test == 'val':
        matrix_file = 'person_specific_neighbour_emotion_val.npy'
    else:
        matrix_file = 'person_specific_neighbour_emotion_test.npy'
    neighbour_matrix = np.load(os.path.join(dataset_path, matrix_file))

    with mp.Pool(processes=p) as pool:
        # Bind the shared references once; starmap supplies the per-sample
        # (neighbour row, predictions) pair as positional arguments.
        score_sample = partial(FRC_func, em=em.numpy())
        per_sample = pool.starmap(score_sample,
                                  zip(neighbour_matrix, pred.numpy()))
    return np.mean(per_sample)


def compute_FRD_mp(dataset_path, pred, em, val_test='val', p=4):
    """Mean FRD over all samples, evaluated in a process pool.

    Loads the neighbour matrix for the requested split from
    ``dataset_path``, pairs each of its rows with the matching entry of
    ``pred`` and scores every pair with ``FRD_func`` in ``p`` worker
    processes.

    Args:
        dataset_path: directory containing the neighbour-matrix ``.npy``
            files.
        pred: tensor of predictions (original note: N x 10 x 750 x dim);
            converted with ``.numpy()``.
        em: tensor of reference sequences (original note: N x 750 x dim);
            converted with ``.numpy()`` and shared with every worker call.
        val_test: ``'val'`` selects the validation matrix; any other value
            selects the test matrix.
        p: number of worker processes.

    Returns:
        ``np.mean`` of the per-sample FRD values.
    """
    if val_test == 'val':
        matrix_file = 'person_specific_neighbour_emotion_val.npy'
    else:
        matrix_file = 'person_specific_neighbour_emotion_test.npy'
    neighbour_matrix = np.load(os.path.join(dataset_path, matrix_file))

    with mp.Pool(processes=p) as pool:
        # Bind the shared references once; starmap supplies the per-sample
        # (neighbour row, predictions) pair as positional arguments.
        score_sample = partial(FRD_func, em=em.numpy())
        per_sample = pool.starmap(score_sample,
                                  zip(neighbour_matrix, pred.numpy()))

    return np.mean(per_sample)


if __name__ == '__main__':
    # Script entry point: for every result directory, load the saved
    # predictions and ground truth, compute FRC and FRD on the test split,
    # append the scores to result.txt and echo them to stdout.

    # Placeholders to fill in before running.
    dataset_path = ''  # directory holding the neighbour-matrix .npy files
    emotion_file_dir = ['']  # directories holding the saved .npy results

    with open('result.txt', 'a') as results_file:
        for result_dir in emotion_file_dir:
            print('Evaluating on ', result_dir, ':')

            predictions = torch.from_numpy(
                np.load(os.path.join(result_dir, 'all_listener_emotion_pred.npy'))
            )
            ground_truth = torch.from_numpy(
                np.load(os.path.join(result_dir, 'all_listener_emotion_gt.npy'))
            )

            # Both metrics run in a 32-process pool. The original note
            # suggests falling back to compute_FRC / compute_FRD if the
            # multiprocessing variants cause problems; those functions are
            # not defined in the visible part of this file.
            frc_score = compute_FRC_mp(dataset_path, predictions, ground_truth,
                                       val_test='test', p=32)
            frd_score = compute_FRD_mp(dataset_path, predictions, ground_truth,
                                       val_test='test', p=32)

            results_file.write(
                "Metric: | FRC: {:.4f} | FRD: {:.4f}\n".format(frc_score, frd_score)
            )

            print("Metric: | FRC: {:.4f} | FRD: {:.4f} ".format(frc_score, frd_score))
